{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 文本分类实例"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step1 导入相关包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\numpy\\_distributor_init.py:30: UserWarning: loaded more than 1 DLL from .libs:\n",
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\numpy\\.libs\\libopenblas.FB5AE2TYXYH2IJRDKGDGQ3XBKLKTF43H.gfortran-win_amd64.dll\n",
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\numpy\\.libs\\libopenblas64__v0.3.21-gcc_10_3_0.dll\n",
      "  warnings.warn(\"loaded more than 1 DLL from .libs:\"\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, BertTokenizer, BertForSequenceClassification\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory the tokenizer and model weights are loaded from via from_pretrained().\n",
    "pretrain_model_dir = \"/data/models/huggingface/rbt3\"\n",
    "\n",
    "# Directory handed to TrainingArguments as output_dir.\n",
    "save_dir = \"/data/logs/remote_ssh_demo\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step2 加载数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['label', 'review'],\n",
       "    num_rows: 7765\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = load_dataset(\"csv\", data_files=\"./ChnSentiCorp_htl_all.csv\", split=\"train\")\n",
    "dataset = dataset.filter(lambda x: x[\"review\"] is not None)\n",
    "dataset"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step3 划分数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['label', 'review'],\n",
       "        num_rows: 6988\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['label', 'review'],\n",
       "        num_rows: 777\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datasets = dataset.train_test_split(test_size=0.1)\n",
    "datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'label': 0, 'review': '很难相信通过携程会订到这么污秽不堪的酒店,浴室的水龙头和花洒已生锈,房间地毯是黑色的。'}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datasets[\"train\"][0]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step4 数据集预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "daf77b52ea35416f83cc99b2e6026c0e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/6988 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6e6bece673ae4f1aadfcc1ef72f6a92a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/777 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],\n",
       "        num_rows: 6988\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],\n",
       "        num_rows: 777\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "tokenizer = BertTokenizer.from_pretrained(pretrain_model_dir)\n",
    "\n",
    "def process_function(examples):\n",
    "    tokenized_examples = tokenizer(examples[\"review\"], max_length=128, truncation=True)\n",
    "    tokenized_examples[\"labels\"] = examples[\"label\"]\n",
    "    return tokenized_examples\n",
    "\n",
    "tokenized_datasets = datasets.map(process_function, batched=True, remove_columns=datasets[\"train\"].column_names)\n",
    "tokenized_datasets"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step5 创建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /data/models/huggingface/rbt3 and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "model = BertForSequenceClassification.from_pretrained(pretrain_model_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BertConfig {\n",
       "  \"_name_or_path\": \"/data/models/huggingface/rbt3\",\n",
       "  \"architectures\": [\n",
       "    \"BertForMaskedLM\"\n",
       "  ],\n",
       "  \"attention_probs_dropout_prob\": 0.1,\n",
       "  \"classifier_dropout\": null,\n",
       "  \"directionality\": \"bidi\",\n",
       "  \"hidden_act\": \"gelu\",\n",
       "  \"hidden_dropout_prob\": 0.1,\n",
       "  \"hidden_size\": 768,\n",
       "  \"initializer_range\": 0.02,\n",
       "  \"intermediate_size\": 3072,\n",
       "  \"layer_norm_eps\": 1e-12,\n",
       "  \"max_position_embeddings\": 512,\n",
       "  \"model_type\": \"bert\",\n",
       "  \"num_attention_heads\": 12,\n",
       "  \"num_hidden_layers\": 3,\n",
       "  \"output_past\": true,\n",
       "  \"pad_token_id\": 0,\n",
       "  \"pooler_fc_size\": 768,\n",
       "  \"pooler_num_attention_heads\": 12,\n",
       "  \"pooler_num_fc_layers\": 3,\n",
       "  \"pooler_size_per_head\": 128,\n",
       "  \"pooler_type\": \"first_token_transform\",\n",
       "  \"position_embedding_type\": \"absolute\",\n",
       "  \"transformers_version\": \"4.43.3\",\n",
       "  \"type_vocab_size\": 2,\n",
       "  \"use_cache\": true,\n",
       "  \"vocab_size\": 21128\n",
       "}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.config"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step6 创建评估函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "path ./metric_accuracy.py None DownloadMode.REUSE_DATASET_IF_EXISTS\n",
      "path ./metric_f1.py None DownloadMode.REUSE_DATASET_IF_EXISTS\n"
     ]
    }
   ],
   "source": [
    "import evaluate\n",
    "\n",
    "acc_metric = evaluate.load(\"./metric_accuracy.py\")\n",
    "f1_metirc = evaluate.load(\"./metric_f1.py\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_metric(eval_predict):\n",
    "    \"\"\"Compute accuracy and F1 for one evaluation pass.\n",
    "\n",
    "    Args:\n",
    "        eval_predict: A pair that unpacks into (model outputs, reference labels).\n",
    "            The outputs must support ``argmax(axis=-1)``.\n",
    "\n",
    "    Returns:\n",
    "        The dict produced by the accuracy metric, extended in place with the\n",
    "        entries produced by the F1 metric.\n",
    "    \"\"\"\n",
    "    scores, references = eval_predict\n",
    "    predicted_ids = scores.argmax(axis=-1)  # index of the highest score along the last axis\n",
    "    metrics = acc_metric.compute(predictions=predicted_ids, references=references)\n",
    "    metrics.update(f1_metirc.compute(predictions=predicted_ids, references=references))\n",
    "    return metrics"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step7 创建TrainingArguments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\transformers\\training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainingArguments(\n",
       "_n_gpu=2,\n",
       "accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},\n",
       "adafactor=False,\n",
       "adam_beta1=0.9,\n",
       "adam_beta2=0.999,\n",
       "adam_epsilon=1e-08,\n",
       "auto_find_batch_size=False,\n",
       "batch_eval_metrics=False,\n",
       "bf16=False,\n",
       "bf16_full_eval=False,\n",
       "data_seed=None,\n",
       "dataloader_drop_last=False,\n",
       "dataloader_num_workers=0,\n",
       "dataloader_persistent_workers=False,\n",
       "dataloader_pin_memory=True,\n",
       "dataloader_prefetch_factor=2,\n",
       "ddp_backend=None,\n",
       "ddp_broadcast_buffers=None,\n",
       "ddp_bucket_cap_mb=None,\n",
       "ddp_find_unused_parameters=None,\n",
       "ddp_timeout=1800,\n",
       "debug=[],\n",
       "deepspeed=None,\n",
       "disable_tqdm=False,\n",
       "dispatch_batches=None,\n",
       "do_eval=True,\n",
       "do_predict=False,\n",
       "do_train=False,\n",
       "eval_accumulation_steps=None,\n",
       "eval_delay=0,\n",
       "eval_do_concat_batches=True,\n",
       "eval_on_start=False,\n",
       "eval_steps=None,\n",
       "eval_strategy=epoch,\n",
       "eval_use_gather_object=False,\n",
       "evaluation_strategy=epoch,\n",
       "fp16=False,\n",
       "fp16_backend=auto,\n",
       "fp16_full_eval=False,\n",
       "fp16_opt_level=O1,\n",
       "fsdp=[],\n",
       "fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False},\n",
       "fsdp_min_num_params=0,\n",
       "fsdp_transformer_layer_cls_to_wrap=None,\n",
       "full_determinism=False,\n",
       "gradient_accumulation_steps=1,\n",
       "gradient_checkpointing=False,\n",
       "gradient_checkpointing_kwargs=None,\n",
       "greater_is_better=True,\n",
       "group_by_length=False,\n",
       "half_precision_backend=auto,\n",
       "hub_always_push=False,\n",
       "hub_model_id=None,\n",
       "hub_private_repo=False,\n",
       "hub_strategy=every_save,\n",
       "hub_token=<HUB_TOKEN>,\n",
       "ignore_data_skip=False,\n",
       "include_inputs_for_metrics=False,\n",
       "include_num_input_tokens_seen=False,\n",
       "include_tokens_per_second=False,\n",
       "jit_mode_eval=False,\n",
       "label_names=None,\n",
       "label_smoothing_factor=0.0,\n",
       "learning_rate=2e-05,\n",
       "length_column_name=length,\n",
       "load_best_model_at_end=True,\n",
       "local_rank=0,\n",
       "log_level=passive,\n",
       "log_level_replica=warning,\n",
       "log_on_each_node=True,\n",
       "logging_dir=/data/logs/remote_ssh_demo\\runs\\Oct15_23-33-02_WIN-20240216IDH,\n",
       "logging_first_step=False,\n",
       "logging_nan_inf_filter=True,\n",
       "logging_steps=10,\n",
       "logging_strategy=steps,\n",
       "lr_scheduler_kwargs={},\n",
       "lr_scheduler_type=linear,\n",
       "max_grad_norm=1.0,\n",
       "max_steps=-1,\n",
       "metric_for_best_model=f1,\n",
       "mp_parameters=,\n",
       "neftune_noise_alpha=None,\n",
       "no_cuda=False,\n",
       "num_train_epochs=3,\n",
       "optim=adamw_torch,\n",
       "optim_args=None,\n",
       "optim_target_modules=None,\n",
       "output_dir=/data/logs/remote_ssh_demo,\n",
       "overwrite_output_dir=False,\n",
       "past_index=-1,\n",
       "per_device_eval_batch_size=128,\n",
       "per_device_train_batch_size=64,\n",
       "prediction_loss_only=False,\n",
       "push_to_hub=False,\n",
       "push_to_hub_model_id=None,\n",
       "push_to_hub_organization=None,\n",
       "push_to_hub_token=<PUSH_TO_HUB_TOKEN>,\n",
       "ray_scope=last,\n",
       "remove_unused_columns=True,\n",
       "report_to=['tensorboard'],\n",
       "restore_callback_states_from_checkpoint=False,\n",
       "resume_from_checkpoint=None,\n",
       "run_name=/data/logs/remote_ssh_demo,\n",
       "save_on_each_node=False,\n",
       "save_only_model=False,\n",
       "save_safetensors=True,\n",
       "save_steps=500,\n",
       "save_strategy=epoch,\n",
       "save_total_limit=3,\n",
       "seed=42,\n",
       "skip_memory_metrics=True,\n",
       "split_batches=None,\n",
       "tf32=None,\n",
       "torch_compile=False,\n",
       "torch_compile_backend=None,\n",
       "torch_compile_mode=None,\n",
       "torch_empty_cache_steps=None,\n",
       "torchdynamo=None,\n",
       "tpu_metrics_debug=False,\n",
       "tpu_num_cores=None,\n",
       "use_cpu=False,\n",
       "use_ipex=False,\n",
       "use_legacy_prediction_loop=False,\n",
       "use_mps_device=False,\n",
       "warmup_ratio=0.0,\n",
       "warmup_steps=0,\n",
       "weight_decay=0.01,\n",
       ")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_args = TrainingArguments(output_dir=save_dir,      # 输出文件夹\n",
    "                               num_train_epochs=3,\n",
    "                               per_device_train_batch_size=64,  # 训练时的batch_size\n",
    "                               per_device_eval_batch_size=128,  # 验证时的batch_size\n",
    "                               logging_steps=10,                # log 打印的频率\n",
    "                               evaluation_strategy=\"epoch\",     # 评估策略\n",
    "                               save_strategy=\"epoch\",           # 保存策略\n",
    "                               save_total_limit=3,              # 最大保存数\n",
    "                               learning_rate=2e-5,              # 学习率\n",
    "                               weight_decay=0.01,               # weight_decay\n",
    "                               metric_for_best_model=\"f1\",      # 设定评估指标\n",
    "                               load_best_model_at_end=True)     # 训练完成后加载最优模型\n",
    "train_args"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step8 创建Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorWithPadding\n",
    "\n",
    "# The held-out \"test\" split doubles as the evaluation set; eval_metric supplies the\n",
    "# metrics and the collator is built from the same tokenizer used for preprocessing.\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=train_args,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"test\"],\n",
    "    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
    "    compute_metrics=eval_metric,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step9 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "104b61e299ba4380890fd598190724c9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/165 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\torch\\nn\\parallel\\_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\torch\\cuda\\nccl.py:15: UserWarning: PyTorch is not compiled with NCCL support\n",
      "  warnings.warn('PyTorch is not compiled with NCCL support')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.638, 'grad_norm': 1.9000219106674194, 'learning_rate': 1.8787878787878792e-05, 'epoch': 0.18}\n",
      "{'loss': 0.57, 'grad_norm': 2.2092020511627197, 'learning_rate': 1.7575757575757576e-05, 'epoch': 0.36}\n",
      "{'loss': 0.456, 'grad_norm': 1.2345517873764038, 'learning_rate': 1.6363636363636366e-05, 'epoch': 0.55}\n",
      "{'loss': 0.3877, 'grad_norm': 1.389718770980835, 'learning_rate': 1.5151515151515153e-05, 'epoch': 0.73}\n",
      "{'loss': 0.3571, 'grad_norm': 4.076654434204102, 'learning_rate': 1.3939393939393942e-05, 'epoch': 0.91}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c1ca8268788340c0ad1435735fe32a01",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.30178898572921753, 'eval_accuracy': 0.8687258687258688, 'eval_f1': 0.9019230769230768, 'eval_runtime': 0.6489, 'eval_samples_per_second': 1197.426, 'eval_steps_per_second': 6.164, 'epoch': 1.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\torch\\nn\\parallel\\_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\torch\\cuda\\nccl.py:15: UserWarning: PyTorch is not compiled with NCCL support\n",
      "  warnings.warn('PyTorch is not compiled with NCCL support')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.378, 'grad_norm': 2.10211181640625, 'learning_rate': 1.2727272727272728e-05, 'epoch': 1.09}\n",
      "{'loss': 0.3238, 'grad_norm': 3.0558106899261475, 'learning_rate': 1.1515151515151517e-05, 'epoch': 1.27}\n",
      "{'loss': 0.3033, 'grad_norm': 1.8828248977661133, 'learning_rate': 1.0303030303030304e-05, 'epoch': 1.45}\n",
      "{'loss': 0.309, 'grad_norm': 1.7313110828399658, 'learning_rate': 9.090909090909091e-06, 'epoch': 1.64}\n",
      "{'loss': 0.2833, 'grad_norm': 1.7349121570587158, 'learning_rate': 7.87878787878788e-06, 'epoch': 1.82}\n",
      "{'loss': 0.2887, 'grad_norm': 2.0048069953918457, 'learning_rate': 6.666666666666667e-06, 'epoch': 2.0}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e087a3d310dc4c43a0810e2c92a8f27c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.25066012144088745, 'eval_accuracy': 0.8957528957528957, 'eval_f1': 0.9240862230552953, 'eval_runtime': 0.6674, 'eval_samples_per_second': 1164.171, 'eval_steps_per_second': 5.993, 'epoch': 2.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\torch\\nn\\parallel\\_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\torch\\cuda\\nccl.py:15: UserWarning: PyTorch is not compiled with NCCL support\n",
      "  warnings.warn('PyTorch is not compiled with NCCL support')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.2661, 'grad_norm': 2.6256706714630127, 'learning_rate': 5.4545454545454545e-06, 'epoch': 2.18}\n",
      "{'loss': 0.265, 'grad_norm': 3.7652876377105713, 'learning_rate': 4.242424242424243e-06, 'epoch': 2.36}\n",
      "{'loss': 0.2569, 'grad_norm': 4.263049125671387, 'learning_rate': 3.0303030303030305e-06, 'epoch': 2.55}\n",
      "{'loss': 0.2889, 'grad_norm': 2.4037749767303467, 'learning_rate': 1.8181818181818183e-06, 'epoch': 2.73}\n",
      "{'loss': 0.2806, 'grad_norm': 1.514642596244812, 'learning_rate': 6.060606060606061e-07, 'epoch': 2.91}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\torch\\nn\\parallel\\_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "db654b0a2568439f83defa863c85a82c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.24281589686870575, 'eval_accuracy': 0.9021879021879022, 'eval_f1': 0.929368029739777, 'eval_runtime': 0.6342, 'eval_samples_per_second': 1225.231, 'eval_steps_per_second': 6.307, 'epoch': 3.0}\n",
      "{'train_runtime': 56.301, 'train_samples_per_second': 372.356, 'train_steps_per_second': 2.931, 'train_loss': 0.3492331447023334, 'epoch': 3.0}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=165, training_loss=0.3492331447023334, metrics={'train_runtime': 56.301, 'train_samples_per_second': 372.356, 'train_steps_per_second': 2.931, 'total_flos': 351909933963264.0, 'train_loss': 0.3492331447023334, 'epoch': 3.0})"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step10 模型评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\torch\\nn\\parallel\\_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b403701423b8477e9ad7f52be32573e3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.24281589686870575,\n",
       " 'eval_accuracy': 0.9021879021879022,\n",
       " 'eval_f1': 0.929368029739777,\n",
       " 'eval_runtime': 0.6634,\n",
       " 'eval_samples_per_second': 1171.23,\n",
       " 'eval_steps_per_second': 6.029,\n",
       " 'epoch': 3.0}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.evaluate(tokenized_datasets[\"test\"])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step11 模型预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\Miniconda\\envs\\geo\\lib\\site-packages\\torch\\nn\\parallel\\_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b6c2c870f47843d3b440a2228407aa1d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "PredictionOutput(predictions=array([[ 0.39005217, -0.5959798 ],\n",
       "       [-2.1953855 ,  2.4256752 ],\n",
       "       [ 1.2242132 , -1.4175045 ],\n",
       "       ...,\n",
       "       [-0.02434846, -0.05769343],\n",
       "       [-0.6431727 ,  1.1999712 ],\n",
       "       [-1.9546117 ,  2.2896514 ]], dtype=float32), label_ids=array([0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,\n",
       "       1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1,\n",
       "       1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0,\n",
       "       1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1,\n",
       "       0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0,\n",
       "       1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1,\n",
       "       1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,\n",
       "       0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0,\n",
       "       1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,\n",
       "       0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0,\n",
       "       0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0,\n",
       "       1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1,\n",
       "       1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0,\n",
       "       1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0,\n",
       "       1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1,\n",
       "       1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1,\n",
       "       0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0,\n",
       "       0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1,\n",
       "       0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1,\n",
       "       0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0,\n",
       "       1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1,\n",
       "       1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1,\n",
       "       0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,\n",
       "       1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1,\n",
       "       1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n",
       "       1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1,\n",
       "       1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1,\n",
       "       1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1,\n",
       "       0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1,\n",
       "       1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0,\n",
       "       1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1,\n",
       "       0, 1, 0, 0, 1, 1, 1], dtype=int64), metrics={'test_loss': 0.24281589686870575, 'test_accuracy': 0.9021879021879022, 'test_f1': 0.929368029739777, 'test_runtime': 0.6313, 'test_samples_per_second': 1230.885, 'test_steps_per_second': 6.337})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.predict(tokenized_datasets[\"test\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# Class-index -> label-name mapping used for the pipeline output.\n",
    "# This must be a plain dict: the previous literal ended with a trailing comma, which made\n",
    "# it a 1-tuple wrapping the dict, and the recorded pipe(sen) call failed with\n",
    "# \"IndexError: tuple index out of range\". Keys are ints to match predicted class indices.\n",
    "id2_label = {0: \"NEGATIVE\", 1: \"POSITIVE\"}\n",
    "\n",
    "model.config.id2label = id2_label\n",
    "# Keep the reverse mapping in sync so the config stays self-consistent.\n",
    "model.config.label2id = {name: idx for idx, name in id2_label.items()}\n",
    "pipe = pipeline(\"text-classification\", model=model, tokenizer=tokenizer, device=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "ename": "IndexError",
     "evalue": "tuple index out of range",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[17], line 2\u001b[0m\n\u001b[0;32m      1\u001b[0m sen \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m我觉得不错！\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m----> 2\u001b[0m \u001b[43mpipe\u001b[49m\u001b[43m(\u001b[49m\u001b[43msen\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32md:\\Miniconda\\envs\\geo\\lib\\site-packages\\transformers\\pipelines\\text_classification.py:156\u001b[0m, in \u001b[0;36mTextClassificationPipeline.__call__\u001b[1;34m(self, inputs, **kwargs)\u001b[0m\n\u001b[0;32m    122\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m    123\u001b[0m \u001b[38;5;124;03mClassify the text(s) given as inputs.\u001b[39;00m\n\u001b[0;32m    124\u001b[0m \n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    153\u001b[0m \u001b[38;5;124;03m    If `top_k` is used, one such dictionary is returned per label.\u001b[39;00m\n\u001b[0;32m    154\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m    155\u001b[0m inputs \u001b[38;5;241m=\u001b[39m (inputs,)\n\u001b[1;32m--> 156\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    157\u001b[0m \u001b[38;5;66;03m# TODO try and retrieve it in a nicer way from _sanitize_parameters.\u001b[39;00m\n\u001b[0;32m    158\u001b[0m _legacy \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtop_k\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m kwargs\n",
      "File \u001b[1;32md:\\Miniconda\\envs\\geo\\lib\\site-packages\\transformers\\pipelines\\base.py:1254\u001b[0m, in \u001b[0;36mPipeline.__call__\u001b[1;34m(self, inputs, num_workers, batch_size, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1246\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mnext\u001b[39m(\n\u001b[0;32m   1247\u001b[0m         \u001b[38;5;28miter\u001b[39m(\n\u001b[0;32m   1248\u001b[0m             \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_iterator(\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1251\u001b[0m         )\n\u001b[0;32m   1252\u001b[0m     )\n\u001b[0;32m   1253\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1254\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_single\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpreprocess_params\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mforward_params\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpostprocess_params\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32md:\\Miniconda\\envs\\geo\\lib\\site-packages\\transformers\\pipelines\\base.py:1262\u001b[0m, in \u001b[0;36mPipeline.run_single\u001b[1;34m(self, inputs, preprocess_params, forward_params, postprocess_params)\u001b[0m\n\u001b[0;32m   1260\u001b[0m model_inputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpreprocess(inputs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpreprocess_params)\n\u001b[0;32m   1261\u001b[0m model_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward(model_inputs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mforward_params)\n\u001b[1;32m-> 1262\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpostprocess\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mpostprocess_params\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1263\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
      "File \u001b[1;32md:\\Miniconda\\envs\\geo\\lib\\site-packages\\transformers\\pipelines\\text_classification.py:222\u001b[0m, in \u001b[0;36mTextClassificationPipeline.postprocess\u001b[1;34m(self, model_outputs, function_to_apply, top_k, _legacy)\u001b[0m\n\u001b[0;32m    219\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnrecognized `function_to_apply` argument: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunction_to_apply\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m    221\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m top_k \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m _legacy:\n\u001b[1;32m--> 222\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabel\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mid2label\u001b[49m\u001b[43m[\u001b[49m\u001b[43mscores\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43margmax\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mscore\u001b[39m\u001b[38;5;124m\"\u001b[39m: scores\u001b[38;5;241m.\u001b[39mmax()\u001b[38;5;241m.\u001b[39mitem()}\n\u001b[0;32m    224\u001b[0m dict_scores \u001b[38;5;241m=\u001b[39m [\n\u001b[0;32m    225\u001b[0m     {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabel\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mid2label[i], 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mscore\u001b[39m\u001b[38;5;124m\"\u001b[39m: score\u001b[38;5;241m.\u001b[39mitem()} \u001b[38;5;28;01mfor\u001b[39;00m i, score \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(scores)\n\u001b[0;32m    226\u001b[0m ]\n\u001b[0;32m    227\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _legacy:\n",
      "\u001b[1;31mIndexError\u001b[0m: tuple index out of range"
     ]
    }
   ],
   "source": [
    "sen = \"我觉得不错！\"\n",
    "pipe(sen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "transformers",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
